import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import ast
import yaml
import seaborn as sns
import statsmodels.formula.api as smf
import scipy
from scipy.stats import variation
import re 
import itertools

# Ordered complexity levels, lowest to highest. composition_vs_perf_plots
# prepends "Overall" to this list to build its hue order and colour palette.
COMPLEXITIES = [
    "Very low",
    "Low",
    "Moderate",
    "High",
    "Very high",
]
def gini(x):
    """Return the Gini coefficient of the values in *x*.

    Uses the relative-mean-absolute-difference definition::

        G = mean(|x_i - x_j|) / (2 * mean(x))

    where the mean of absolute differences is taken over all ordered pairs
    (i, j), including i == j.
    Source: https://stackoverflow.com/questions/39512260/calculating-gini-coefficient-in-python-numpy

    Parameters
    ----------
    x : array-like of numbers
        Lists, tuples, numpy arrays and pandas Series are all accepted. The
        input is converted to a float ndarray first because pandas does not
        support the ufunc ``outer`` method on Series objects.

    Returns
    -------
    float
        0 when all values are equal. NaN when *x* is empty or its mean is
        zero, since the coefficient is undefined there.

    Notes
    -----
    Time and memory are O(n**2): every pairwise difference is materialized.
    """
    values = np.asarray(x, dtype=float)
    mean_value = values.mean() if values.size else 0.0
    if mean_value == 0:
        # 0/0 (or x/0): undefined. Return NaN explicitly rather than letting
        # numpy emit "Mean of empty slice" / divide RuntimeWarnings.
        return float("nan")
    mad = np.abs(np.subtract.outer(values, values)).mean()
    rmad = mad / mean_value
    return rmad / 2

def get_univariate_coef_df(df, independent_variables, outcomes, complexities, complexity_based):
    """Fit one univariate OLS model per (independent variable, outcome) pair.

    For every pair the formula ``"<outcome> ~ <independent_variable>"`` is
    fitted with ``smf.ols`` and its non-intercept coefficients are collected.

    Parameters
    ----------
    df : pandas.DataFrame
        Data holding the outcome and independent-variable columns. In
        complexity-based mode it must also have a ``complexity`` column; in
        overall mode it must have a ``game_id`` column.
    independent_variables : iterable of str
        Right-hand-side terms, each fitted on its own.
    outcomes : iterable of str
        Left-hand-side columns.
    complexities : iterable
        Values of ``df.complexity`` to fit separately. Only used when
        *complexity_based* is true.
    complexity_based : bool
        If true, fit one model per complexity level on the matching rows.
        Otherwise fit one model on ``df`` de-duplicated by ``game_id``.

    Returns
    -------
    pandas.DataFrame
        One row per non-intercept coefficient. Columns are ``coef``, ``err``
        (coefficient minus the lower bound of ``model.conf_int()``), and
        ``independent_variable``. Complexity-based mode adds ``complexity``.
        Both modes then have ``outcome`` and ``R_squared``.
        Index values come from the per-model frames and are not unique.
        An empty DataFrame is returned when there is nothing to fit.
    """
    frames = []
    for independent_variable in independent_variables:
        for outcome in outcomes:
            formula = "{} ~ {}".format(outcome, independent_variable)

            if complexity_based:
                for complexity in complexities:
                    model = smf.ols(formula, data=df[df.complexity == complexity]).fit()

                    err_series = model.params - model.conf_int()[0]
                    frames.append(pd.DataFrame({
                        'coef': model.params.values[1:],
                        'err': err_series.values[1:],
                        # Strip the treatment-coding suffix statsmodels adds to boolean terms.
                        # NOTE(review): every row of this model gets the FIRST parameter's
                        # name, even when a categorical term yields several coefficients.
                        'independent_variable': re.sub(r"\[T\.True\]", "", err_series.index.values[1]),
                        'complexity': complexity,
                        'outcome': outcome,
                        'R_squared': model.rsquared,
                    }))

            else:
                model = smf.ols(formula, data=df.drop_duplicates(subset=["game_id"])).fit()

                err_series = model.params - model.conf_int()[0]
                frames.append(pd.DataFrame({
                    'coef': model.params.values[1:],
                    'err': err_series.values[1:],
                    # NOTE(review): unlike the complexity-based branch, the "[T.True]"
                    # suffix is not stripped here -- confirm this is intended.
                    'independent_variable': err_series.index.values[1],
                    'outcome': outcome,
                    'R_squared': model.rsquared,
                }))

    # DataFrame.append was removed in pandas 2.0; concatenate once at the end.
    # pd.concat raises on an empty list, so keep the old empty-frame result.
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames)


def complexity_effect_plots(model, list_coefs, n_multiple_comparisons=None, xticklabels=None, xlabelrotation=80):
    """Plot selected model coefficients with error bars on the current pyplot axes.

    Two error bars are drawn per coefficient:

    - a default-coloured bar of 1.96 standard errors;
    - a translucent black bar of ``critical_val`` standard errors. This is a
      Bonferroni-corrected critical value when *n_multiple_comparisons* is
      given, otherwise 1.96 again.

    Parameters
    ----------
    model : fitted regression results
        Must expose ``params`` and ``bse`` (indexable by parameter name) and
        ``model.endog_names``. Presumably a statsmodels results object.
    list_coefs : list of str
        Parameter names to plot. They are also used as the x positions.
    n_multiple_comparisons : int, optional
        Number of comparisons for the Bonferroni correction. A falsy value
        means no correction.
    xticklabels : sequence of str, optional
        Replacement tick labels, one per entry of *list_coefs*.
    xlabelrotation : float
        Rotation of the x tick labels, in degrees.

    Returns
    -------
    The container returned by the second ``plt.errorbar`` call (the
    ``critical_val`` interval). It is not a Figure.
    """
    mean = model.params[list_coefs]
    se = model.bse[list_coefs]
    if n_multiple_comparisons:
        # Two-sided Bonferroni: split alpha = 0.05 across comparisons and both tails.
        critical_val = abs(scipy.stats.norm.ppf(0.05/(n_multiple_comparisons * 2)))
        print("Correcting for {} comparisons by using {} stdevs for 95% CI".format(n_multiple_comparisons, critical_val))
    else:
        critical_val = 1.96
    plt.errorbar(x=list_coefs, y=mean, yerr=1.96 * se, linestyle="")
    errorbar_container = plt.errorbar(x=list_coefs, y=mean, yerr=critical_val * se, linestyle="", alpha=0.3, color="black")

    # Explicit checks instead of truthiness: numpy arrays / pandas Index raise on bool().
    if xticklabels is not None and len(xticklabels) > 0:
        plt.xticks(list_coefs, xticklabels)

    plt.scatter(x=list_coefs, y=mean)
    plt.axhline(y=0, linestyle="--", color="black")
    plt.xticks(rotation=xlabelrotation)
    plt.title("Dependent Variable: {}".format(model.model.endog_names))
    plt.ylabel("coef")

    return errorbar_container


def complexity_effect_plots_ax(ax, model, list_coefs, n_multiple_comparisons=None, xticklabels=None, xlabelrotation=80):
    """Plot selected model coefficients with error bars on the given axes.

    This is the axes-based variant of ``complexity_effect_plots``, with
    thicker error bars and larger markers. Two error bars are drawn per
    coefficient:

    - a default-coloured bar of 1.96 standard errors;
    - a translucent black bar of ``critical_val`` standard errors. This is a
      Bonferroni-corrected critical value when *n_multiple_comparisons* is
      given, otherwise 1.96.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    model : fitted regression results
        Must expose ``params``, ``bse`` and ``model.endog_names``. Presumably
        a statsmodels results object.
    list_coefs : list of str
        Parameter names to plot. They are also used as the x positions.
    n_multiple_comparisons : int, optional
        Number of comparisons for the Bonferroni correction. A falsy value
        means no correction.
    xticklabels : sequence of str, optional
        Replacement tick labels. When omitted, the existing labels are kept
        and only rotated.
    xlabelrotation : float
        Rotation of the x tick labels, in degrees.
    """
    mean = model.params[list_coefs]
    se = model.bse[list_coefs]
    if n_multiple_comparisons:
        # Two-sided Bonferroni: split alpha = 0.05 across comparisons and both tails.
        critical_val = abs(scipy.stats.norm.ppf(0.05/(n_multiple_comparisons * 2)))
        print("Correcting for {} comparisons by using {} stdevs for 95% CI".format(n_multiple_comparisons, critical_val))
    else:
        critical_val = 1.96
    ax.errorbar(x=list_coefs, y=mean, yerr=1.96 * se, linestyle="", elinewidth=4)
    ax.errorbar(x=list_coefs, y=mean, yerr=critical_val * se, linestyle="", alpha=0.3, color="black", elinewidth=4)

    ax.scatter(x=list_coefs, y=mean, s=150)
    ax.axhline(y=0, linestyle="--", color="black")
    if xticklabels is not None:
        ax.set_xticklabels(labels=xticklabels, rotation=xlabelrotation)
    else:
        # set_xticklabels(None) would try to iterate None; just rotate the existing labels.
        ax.tick_params(axis="x", labelrotation=xlabelrotation)
    ax.set_title("Dependent Variable: {}".format(model.model.endog_names))
    ax.set_ylabel("coef")


def get_turn_taking_variation(consolidated_actions):
    """Return the coefficient of variation of the number of turns per player.

    Turns are counted per distinct ``subject_id``. When fewer than three ids
    appear, the list is padded with zero-turn players up to three (presumably
    the group size -- confirm). The result is ``scipy.stats.variation`` of
    those counts.
    """
    counts = consolidated_actions.subject_id.value_counts().tolist()
    missing_players = 3 - len(counts)
    if missing_players > 0:
        counts.extend([0] * missing_players)
    return variation(counts)

def get_turn_taking_gini(consolidated_actions):
    """Return the Gini coefficient of the number of turns per player.

    Turns are counted per distinct ``subject_id``. When fewer than three ids
    appear, the list is padded with zero-turn players up to three (presumably
    the group size -- confirm) before calling ``gini``.
    """
    counts = consolidated_actions.subject_id.value_counts().tolist()
    counts += [0] * max(0, 3 - len(counts))
    return gini(counts)

def _draw_split_errorbars(ax, x_coords, y_coords, errorbars, overall_line_index, erroralpha):
    """Draw grey horizontal error bars in two groups.

    All but the last *overall_line_index* points use *erroralpha*; the last
    *overall_line_index* points are drawn fully opaque.
    """
    ax.errorbar(x_coords[:-overall_line_index], y_coords[:-overall_line_index],
                xerr=errorbars[:-overall_line_index],
                fmt=' ', zorder=-1, ecolor="grey", alpha=erroralpha)

    ax.errorbar(x_coords[-overall_line_index:], y_coords[-overall_line_index:],
                xerr=errorbars[-overall_line_index:],
                fmt=' ', zorder=-1, ecolor="grey", alpha=1)


def composition_vs_perf_plots(df_regressions, outcomes, variable_order, n_multiple_comp_correction=None, separating_lines=None, ind_var_labels=None, complexity=True, figsize=(15,15), exclude_from_correction=None, erroralpha=1):
    """Draw one coefficient point plot per outcome, side by side, with error bars.

    Parameters
    ----------
    df_regressions : pandas.DataFrame
        Coefficient table with at least ``outcome``, ``independent_variable``,
        ``coef`` and ``err`` columns. ``err`` is the coefficient minus the
        lower bound of its 95% CI. A ``complexity`` column is also required
        when *complexity* is true.
    outcomes : list of str
        One subplot per outcome, in this order.
    variable_order : list of str
        Independent variables to show, in display order.
    n_multiple_comp_correction : int, optional
        If given, error bars are rescaled from the 95% CI to a
        Bonferroni-corrected interval for this many comparisons.
    separating_lines : sequence of int, optional
        Horizontal lines are drawn at the cumulative sums of these values,
        with the first value shifted by -0.5 so lines fall between points.
        The caller's sequence is not modified.
    ind_var_labels : list of str, optional
        Replacement y tick labels.
    complexity : bool
        If true, points are coloured by ``complexity`` ("Overall" plus
        COMPLEXITIES).
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    exclude_from_correction : collection of str, optional
        Independent variables whose error bars are set to zero width when
        the correction is applied.
    erroralpha : float
        Alpha of all error bars except the last ``overall_line_index`` points
        on each axis, which are always opaque.

    Returns
    -------
    (fig, axes)
        As returned by ``plt.subplots``.
    """
    fig, axes = plt.subplots(nrows=1, ncols=len(outcomes), figsize=figsize, sharex=True, sharey=True)
    # plt.subplots returns a bare Axes instead of an array when there is one column.
    axes_list = [axes] if len(outcomes) == 1 else axes

    if separating_lines is not None and len(separating_lines) > 0:
        # Work on a float copy: mutating the caller's sequence would shift the
        # lines again on every call, and an integer numpy array would truncate
        # the 0.5 offset.
        separating_lines = [float(value) for value in separating_lines]
        separating_lines[0] = separating_lines[0] - 1 + 0.5  # -1 because of how matplotlib counts, 0.5 to put in between points

    hue_levels = ["Overall"] + COMPLEXITIES
    palette = dict(zip(hue_levels, ["black", "#1fad1f", "#77FF00", "#FFFF00", "#FF8800", "#FF0000"]))

    for plot_index, outcome in enumerate(outcomes):
        ax = axes_list[plot_index]
        temp_df = df_regressions.query("outcome==@outcome and independent_variable in @variable_order").assign(var_order=lambda x: x['independent_variable'].apply(lambda y: variable_order.index(y)))

        if complexity == True:
            sns.pointplot(data=temp_df.sort_values("var_order", ascending=True),
                          x="coef", y="independent_variable", hue="complexity", dodge=0.5, linestyles="",
                          # A list, not a one-shot reversed() iterator that could be consumed early.
                          hue_order=list(reversed(hue_levels)), palette=palette,
                          ax=ax)
        else:
            sns.pointplot(data=temp_df.sort_values("var_order", ascending=True),
                          x="coef", y="independent_variable", linestyles="",
                          ax=ax)

        ax.axvline(x=0, linestyle="--", color="black", alpha=0.2)
        ax.set_title(outcome)

        if separating_lines is not None:
            for y_line in np.cumsum(separating_lines):
                ax.axhline(y=y_line, color="black", alpha=1)

        # Recover the plotted point positions from the collections seaborn added.
        x_coords, y_coords = ([], [])
        for point_pair in ax.collections:
            for x, y in point_pair.get_offsets():
                x_coords.append(x)
                y_coords.append(y)
        # NOTE(review): errors are matched to points by exact float equality of
        # ``coef`` across the whole df_regressions (first match wins), so rows
        # sharing a coefficient value would share an error bar.
        errorbars = [df_regressions.query("coef == @x").err.values[0] for x in x_coords]

        # Number of distinct plotted variables. The last this-many points are
        # drawn opaque -- presumably the "Overall" hue, which is last in
        # hue_order (confirm).
        overall_line_index = df_regressions.query("independent_variable != 'Intercept' and independent_variable != 'game_id Var'")['independent_variable'].nunique()

        if n_multiple_comp_correction is None:
            _draw_split_errorbars(ax, x_coords, y_coords, errorbars, overall_line_index, erroralpha)

        # get_legend() returns None when nothing created a legend (e.g. no hue).
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

        ax.set_ylabel("")

        if n_multiple_comp_correction is not None:
            correction_critical_value = abs(scipy.stats.norm.ppf(0.05/(n_multiple_comp_correction * 2)))
            print("Using {} as critical value instead of 1.96 for multiple comparisons correction".format(correction_critical_value))

            if exclude_from_correction is not None:
                errorbars = [df_regressions.query("coef == @x").err.values[0] if df_regressions.query("coef == @x").independent_variable.values[0] not in exclude_from_correction else 0 for x in x_coords]

            # Divide by 1.96 since the stored "err" is (mean - lower bound of the 95% CI).
            errorbars = np.array(errorbars) * correction_critical_value / 1.96

            # Plot the wider CI corresponding to the correction.
            _draw_split_errorbars(ax, x_coords, y_coords, errorbars, overall_line_index, erroralpha)

        ax.tick_params(axis="both", labelsize=20)

        if ind_var_labels is not None:
            ax.set_yticklabels(labels=ind_var_labels, fontsize=20)

    return fig, axes